# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import datetime
from app.exceptions import InvalidIdentifierException

from data.utils import populate_models
from data.models.base import Base, ma
from data.models.user import User
from sqlalchemy import Boolean, Column, Integer, String, Text, ForeignKey, Index, DateTime
from sqlalchemy.dialects.mysql import LONGTEXT
from sqlalchemy.orm import relationship
from sqlalchemy_fulltext import FullText
from lib.vcs_management import getVcsHandler

import marshmallow


class RevisionMixin:
  """Columns and a helper for revisioned, soft-archivable rows.

  ``archive()`` does not delete the row; it flags it inactive and stamps the
  time it was archived.
  """
  # Revision counter, 0 for new rows. Nothing in this mixin increments it.
  revision = Column(Integer, nullable=False, default=0)
  # Relationships in this module filter on active == True, so archived rows
  # drop out of them.
  active = Column(Boolean, nullable=False, default=True)
  # Set by archive(). Declared without timezone=True, so a naive UTC value is
  # stored.
  archived_at = Column(DateTime)

  def archive(self):
    """Mark this row inactive and record the archive time (naive UTC)."""
    self.active = False
    # datetime.utcnow() is deprecated since Python 3.12. This yields the same
    # naive-UTC value, matching the timezone-less DateTime column above.
    self.archived_at = datetime.datetime.now(
        datetime.timezone.utc).replace(tzinfo=None)


class MarshmallowBase(ma.ModelSchema):
  """Common base for the model schemas in this module.

  Subclasses whose ``Meta`` derives from ``MarshmallowBase.Meta`` inherit its
  ``exclude`` list unless they assign their own ``exclude``.
  """

  __abstract__ = True

  class Meta:
    # Bookkeeping columns hidden from serialized output by default.
    exclude = ('id', 'date_created', 'date_modified')


class VulnerabilityResources(Base):
  """An external resource link attached to a vulnerability.

  Rows are reached from ``Vulnerability.resource_links``.
  """
  # NOTE(review): no __tablename__ is set here; presumably Base derives one —
  # confirm.
  __table_args__ = {'schema': 'main'}
  link = Column(String(1000), nullable=False)
  # Owning row in main.vulnerability.
  vulnerability_details_id = Column(Integer,
                                    ForeignKey('main.vulnerability.id'))


class CreatorSchema(MarshmallowBase):
  """Schema for the ``User`` who created a marker or comment.

  Used below only as a nested field restricted to ``only=['name']``.
  """

  # This Meta does not derive from MarshmallowBase.Meta, so the base
  # id/date_created/date_modified excludes do not apply here.
  class Meta:
    model = User


class RepositoryFileMarkers(RevisionMixin, Base):
  """A marker over a row/column range of a repository file.

  The kind of marker is given by ``marker_class``. Revision and archive
  state come from ``RevisionMixin``.
  """
  __tablename__ = 'repository_file_markers'
  # Range covered by the marker; all bounds are optional.
  row_from = Column(Integer)
  row_to = Column(Integer)
  column_from = Column(Integer)
  column_to = Column(Integer)
  marker_class = Column(String(100), nullable=False)
  repository_file_id = Column(Integer, ForeignKey('repository_files.id'))
  # Optional author of the marker.
  creator_id = Column(Integer, ForeignKey(User.id), nullable=True)
  creator = relationship(User)


class RepositoryFileMarkersSchema(MarshmallowBase):
  """Schema for ``RepositoryFileMarkers``, nesting only the creator's name."""

  # Inherits the id/date_created/date_modified excludes from the base Meta.
  class Meta(MarshmallowBase.Meta):
    model = RepositoryFileMarkers

  # Only expose the creator's name, not the whole User record.
  creator = marshmallow.fields.Nested(CreatorSchema, only=['name'])


class RepositoryFileComments(RevisionMixin, Base):
  """A text comment attached to a row range of a repository file.

  Revision and archive state come from ``RevisionMixin``.
  """
  __tablename__ = 'repository_file_comments'
  # Row range the comment refers to; both bounds are optional.
  row_from = Column(Integer)
  row_to = Column(Integer)
  text = Column(Text, nullable=False)
  # NOTE(review): presumably a display-ordering key — confirm with callers.
  sort_pos = Column(Integer)
  repository_file_id = Column(Integer, ForeignKey('repository_files.id'))
  # Optional author of the comment.
  creator_id = Column(Integer, ForeignKey(User.id), nullable=True)
  creator = relationship(User)


class RepositoryFileCommentsSchema(MarshmallowBase):
  """Schema for ``RepositoryFileComments``, nesting only the creator's name."""

  class Meta(MarshmallowBase.Meta):
    model = RepositoryFileComments
    # NOTE(review): this assignment replaces (does not extend) the inherited
    # MarshmallowBase.Meta.exclude. As a result id/date_created/date_modified
    # ARE serialized for comments, while RepositoryFileMarkersSchema hides
    # them but exposes archived_at/active. Confirm this asymmetry is intended.
    exclude = ['archived_at', 'active']

  # Only expose the creator's name, not the whole User record.
  creator = marshmallow.fields.Nested(CreatorSchema, only=['name'])


class RepositoryFiles(Base):
  """A file belonging to a vulnerability commit.

  Also holds the file's markers and comments.
  """
  __tablename__ = 'repository_files'
  file_name = Column(String(1000), nullable=False)
  file_path = Column(String(1000), nullable=False)
  file_hash = Column(String(1000), nullable=False)
  # A cached version of all file changes for the given commit.
  file_patch = Column(Text, nullable=False)

  # Only markers with active == True are loaded (archived ones are hidden).
  markers = relationship(
      RepositoryFileMarkers,
      backref='repository_file',
      cascade='all, delete-orphan',
      primaryjoin='and_(RepositoryFiles.id==RepositoryFileMarkers.repository_file_id, RepositoryFileMarkers.active==True)',
      single_parent=True)

  # Only comments with active == True are loaded (archived ones are hidden).
  comments = relationship(
      RepositoryFileComments,
      backref='repository_file',
      cascade='all, delete-orphan',
      primaryjoin='and_(RepositoryFiles.id==RepositoryFileComments.repository_file_id, RepositoryFileComments.active==True)',
      single_parent=True)
  # Owning commit (see VulnerabilityGitCommits.repository_files).
  commit_id = Column(Integer, ForeignKey('vulnerability_git_commits.id'))


class RepositoryFilesSchema(MarshmallowBase):
  """Schema for ``RepositoryFiles`` with nested markers and comments."""
  # TODO: Add exclude=[] parameter here to skip redundant date and id fields.
  # NOTE(review): Meta below already derives from MarshmallowBase.Meta and so
  # inherits its id/date_created/date_modified excludes; this TODO may be
  # stale — confirm.
  file_patch = ma.Method('get_patch')

  def get_patch(self, obj):
    """Return a fixed placeholder instead of the cached patch text.

    Args:
      obj: the RepositoryFiles instance being serialized (unused).

    Returns:
      The constant string 'DEPRECATED'.
    """
    return 'DEPRECATED'

  markers = ma.Nested(RepositoryFileMarkersSchema, many=True)
  comments = ma.Nested(RepositoryFileCommentsSchema, many=True)

  class Meta(MarshmallowBase.Meta):
    model = RepositoryFiles


class VulnerabilityGitCommits(Base):
  """A git commit associated with a vulnerability.

  Stores the commit hash and link, repository identification, a cached
  repository tree and the files touched by the commit. Active comments and
  markers on those files are also exposed directly through read-only
  relationships.
  """
  __tablename__ = 'vulnerability_git_commits'

  commit_hash = Column(String(1000), nullable=False)
  # Accessed through the validating `commit_link` property below.
  _commit_link = Column('commit_link', String(1000), nullable=False)
  repo_name = Column(String(1000), nullable=False)
  repo_owner = Column(String(1000))
  # URL to a *.git Git repository (if applicable).
  # Accessed through the `repo_url` property below.
  _repo_url = Column('repo_url', String(1000))
  vulnerability_details_id = Column(Integer,
                                    ForeignKey('main.vulnerability.id'))
  # Used to store/cache the repository tree files with hashes.
  tree_cache = Column(LONGTEXT())

  repository_files = relationship(
      RepositoryFiles,
      backref='commit',
      cascade='all, delete-orphan',
      single_parent=True,
  )
  # link to comments through RepositoryFiles
  # (viewonly; only comments with active == True are included)
  comments = relationship(
      RepositoryFileComments,
      backref='commit',
      secondary=RepositoryFiles.__table__,
      primaryjoin='VulnerabilityGitCommits.id==RepositoryFiles.commit_id',
      secondaryjoin='and_(RepositoryFiles.id==RepositoryFileComments.repository_file_id, RepositoryFileComments.active==True)',
      viewonly=True,
  )
  # link to markers through RepositoryFiles
  # (viewonly; only markers with active == True are included)
  markers = relationship(
      RepositoryFileMarkers,
      backref='commit',
      secondary=RepositoryFiles.__table__,
      primaryjoin='VulnerabilityGitCommits.id==RepositoryFiles.commit_id',
      secondaryjoin='and_(RepositoryFiles.id==RepositoryFileMarkers.repository_file_id, RepositoryFileMarkers.active==True)',
      viewonly=True,
  )

  @property
  def num_files(self):
    """Number of files attached to this commit (loads all of them)."""
    # TODO: This should be refactored as it is incredibly inefficient.
    #       We should use a count on the database side instead.
    return len(self.repository_files)

  @property
  def num_comments(self):
    """Number of active comments across this commit's files."""
    # TODO: see comment regarding performance above.
    return len(self.comments)

  @property
  def num_markers(self):
    """Number of active markers across this commit's files."""
    # TODO: see comment regarding performance above.
    return len(self.markers)

  @property
  def repo_url(self):
    """Return the stored repository URL, or derive a GitHub one.

    A URL is derived only when no URL is stored, the commit link contains
    'github.com', and both repo_owner and repo_name are set. In that case the
    result is 'https://github.com/<repo_owner>/<repo_name>'. Otherwise the
    stored value is returned as-is (possibly None or empty).
    """
    if not self._repo_url:
      # TODO: Refactor this approach of retrieving github.com urls.
      # NOTE(review): plain substring test; any link merely containing
      # 'github.com' matches.
      if self.commit_link and 'github.com' in self.commit_link:
        if self.repo_owner and self.repo_name:
          return 'https://github.com/' + self.repo_owner + '/' + self.repo_name
    return self._repo_url

  @repo_url.setter
  def repo_url(self, repo_url):
    """Store repo_url verbatim (validation happens in __init__ only)."""
    self._repo_url = repo_url

  @property
  def commit_link(self):
    """The stored commit link."""
    return self._commit_link

  @commit_link.setter
  def commit_link(self, commit_link):
    """Store the commit link after a minimal sanity check.

    Raises:
      InvalidIdentifierException: if commit_link is non-empty and does not
        start with 'http'.
    """
    # TODO: Add commit link sanitization back here. We're currently skipping
    #  it as on object creation (populate) there might be no repo_url set
    #  and the commit_link might be just a VCS UI link to the patch.
    #  We should still always require a separate repository link and commit
    #  hash if it's not a simple Github entry.
    #if not self.repo_url and commit_link:
    # vcs_handler = getVcsHandler(None, commit_link)
    # if not vcs_handler:
    #   raise InvalidIdentifierException('Please provide a valid commit link.')
    if commit_link:
      if not commit_link.startswith('http'):
        raise InvalidIdentifierException('Please provide a valid commit link.')

    self._commit_link = commit_link

  def __init__(self,
               commit_link=None,
               repo_owner=None,
               repo_name=None,
               repo_url=None,
               commit_hash=None):
    """Create a commit record, validating repo_url and commit_link.

    Raises:
      InvalidIdentifierException: if repo_url is given and getVcsHandler
        returns no handler for it, or if commit_link fails the check in the
        commit_link setter.
    """
    self.repo_owner = repo_owner
    self.repo_name = repo_name
    if repo_url:
      vcs_handler = getVcsHandler(None, repo_url)
      if not vcs_handler:
        raise InvalidIdentifierException('Please provide a valid git repo URL.')
      self.repo_url = repo_url
    # Goes through the commit_link property setter (validation above).
    self.commit_link = commit_link
    self.commit_hash = commit_hash


# Non-unique index on vulnerability_git_commits.commit_hash, supporting
# lookups by commit hash such as Vulnerability.get_by_commit_hash.
Index('cl_hash_index', VulnerabilityGitCommits.commit_hash)

# TODO: adjust indices!


class Vulnerability(FullText, Base):
  """A vulnerability entry stored in the ``main`` schema.

  Holds a free-text ``comment`` (listed in ``__fulltext_columns__`` for
  full-text search), an optional ``cve_id`` foreign key to the NVD data, and
  the resource links and git commits associated with the vulnerability.
  """
  __fulltext_columns__ = ('comment',)
  __tablename__ = 'vulnerability'
  __table_args__ = {'schema': 'main'}
  comment = Column(Text, nullable=False)
  exploit_exists = Column(Boolean, default=False)
  cve_id = Column(String(255), ForeignKey('cve.nvds.cve_id'), nullable=True)
  # Optional author of the entry.
  creator_id = Column(Integer, ForeignKey(User.id), nullable=True)
  creator = relationship(User)

  resource_links = relationship(
      VulnerabilityResources,
      backref='vulnerability',
      cascade='all, delete-orphan',
      single_parent=True)

  commits = relationship(
      VulnerabilityGitCommits,
      backref='vulnerability',
      cascade='all, delete-orphan',
      single_parent=True)

  # 'Nvd' is resolved by name; it is defined outside this module.
  nvd = relationship('Nvd', backref='vulnerability', single_parent=True)

  @property
  def master_commit(self):
    """Return the first associated commit, or None if there are none."""
    # TODO: refactor assumption that the first commit is the "master" one!
    if self.commits:
      return self.commits[0]
    return None

  def __repr__(self):
    # '{:s}' is not a valid format spec for a dict (dict.__format__ raises
    # TypeError), which made repr() itself blow up. Use plain '{}' instead.
    return 'Vulnerability Info({})'.format(vars(self))

  @classmethod
  def get_by_cve_id(cls, cve_id):
    """Return the first entry with the given CVE id, or None."""
    return cls.query.filter_by(cve_id=cve_id).first()

  @classmethod
  def get_by_commit_hash(cls, commit_hash):
    """Return the first entry having a commit with ``commit_hash``, or None.

    The join is aliased, so ``filter_by`` applies to the joined commit rows.
    """
    return cls.query.join(
        cls.commits,
        aliased=True).filter_by(commit_hash=commit_hash).first()


# must be set after all definitions
# NOTE(review): populate_models presumably collects the model classes defined
# in this module, which is why it has to run after every class above — confirm
# in data.utils.
__all__ = populate_models(__name__)
